#include <torch/extension.h>
#include "utils.h"



/**
 * Python-facing entry point for trilinear interpolation.
 *
 * Validates both input tensors with CHECK_INPUT (a macro from utils.h;
 * presumably it checks device placement and contiguity, so confirm there).
 * It then delegates the actual computation to trilinear_fw_cu, which is
 * also declared outside this file.
 *
 * @param feats   Feature tensor to interpolate. The expected shape is
 *                defined by trilinear_fw_cu (TODO confirm).
 * @param points  Query points. The expected shape is defined by
 *                trilinear_fw_cu (TODO confirm).
 * @return        Whatever tensor trilinear_fw_cu produces for these inputs.
 */
torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor points)
{
    // Validate in the same order as the arguments so the first invalid one is reported.
    CHECK_INPUT(feats);
    CHECK_INPUT(points);

    auto interpolated = trilinear_fw_cu(feats, points);
    return interpolated;
}

// Python bindings for this extension. TORCH_EXTENSION_NAME is expected to be
// defined by the build (e.g. torch.utils.cpp_extension).
// In m.def, the first argument is the Python-visible name and the second is
// the C++ function being bound. The py::arg entries expose the parameter
// names, so Python callers may pass them by keyword; positional calls still work.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("trilinear_interpolation", &trilinear_interpolation,
          "Trilinear interpolation of feats at points (forward pass via trilinear_fw_cu).",
          py::arg("feats"), py::arg("points"));
}